Cashew Disease Classification Using Convolutional Neural Networks (CNN)¶
Project Members¶
Ayush Panchal - P24DS013¶
Avanti Thale - P24DS007¶
Pooja Dave - P24DS012¶
In this project, we will implement a machine learning model using Convolutional Neural Networks (CNN) to classify diseases in cashew trees based on images of their leaves. The goal is to automate the process of detecting diseases, which can help farmers manage crop health and improve productivity.
Objectives¶
- Preprocess and prepare a dataset of images containing cashew leaves affected by various diseases.
- Implement a CNN model to learn patterns in the images and classify them into different categories.
- Evaluate the model's performance using metrics such as accuracy, precision, recall, and F1-score.
Key Steps¶
- Data Collection: We will gather images of healthy and diseased cashew leaves. These images will serve as our dataset.
- Data Preprocessing: Resize, normalize, and augment the images to prepare them for training.
- Model Building: Implement using Keras/TensorFlow to classify the images.
- Training: Train the CNN model on the preprocessed dataset.
- Evaluation: Evaluate the model's performance using test data and metrics.
- Prediction: Test the model with new, unseen images to predict disease classification.
By the end of this project, we aim to develop a robust model that can assist in early disease detection for cashew trees, ultimately aiding in better crop management and yield.
Let's get started by loading the dataset and performing some initial exploration!
1. Setup and importing necessary libraries¶
1.1 Ignoring warnings shown in the notebook¶
# Silence all Python warnings so the notebook output stays readable.
import warnings
warnings.filterwarnings('ignore')
1.2 Installing gdown in the environment¶
Gdown: used to download a public file or folder from Google Drive. It handles Google Drive's download confirmation pages, which curl/wget cannot.
!pip install gdown
Requirement already satisfied: gdown in /opt/conda/lib/python3.10/site-packages (5.2.0) Requirement already satisfied: beautifulsoup4 in /opt/conda/lib/python3.10/site-packages (from gdown) (4.12.3) Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from gdown) (3.15.1) Requirement already satisfied: requests[socks] in /opt/conda/lib/python3.10/site-packages (from gdown) (2.32.3) Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from gdown) (4.66.4) Requirement already satisfied: soupsieve>1.2 in /opt/conda/lib/python3.10/site-packages (from beautifulsoup4->gdown) (2.5) Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests[socks]->gdown) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests[socks]->gdown) (3.7) Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests[socks]->gdown) (1.26.18) Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests[socks]->gdown) (2024.8.30) Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /opt/conda/lib/python3.10/site-packages (from requests[socks]->gdown) (1.7.1)
1.3 Necessary Imports¶
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import cv2
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm
import os
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard, ModelCheckpoint
from sklearn.metrics import classification_report,confusion_matrix,ConfusionMatrixDisplay, accuracy_score
import ipywidgets as widgets
import io
from PIL import Image
from IPython.display import display,clear_output
1.4 Setting up color pallets¶
# Custom hex color palettes reused in later plots (e.g. figure titles).
colors_dark = ["#1F1F1F", "#313131", '#636363', '#AEAEAE', '#DADADA']
colors_red = ["#331313", "#582626", '#9E1717', '#D35151', '#E9B4B4']
colors_green = ['#01411C','#4B6F44','#4F7942','#74C365','#D0F0C0']
# Preview each palette as a horizontal strip of swatches.
sns.palplot(colors_dark)
sns.palplot(colors_green)
sns.palplot(colors_red)
1.5 Setting up label/class names¶
# Class names, ordered to match the dataset's folder names; a name's index in
# this list is used later as its integer class id for one-hot encoding.
labels = ['anthracnose', 'gumosis', 'healthy', 'leaf miner', 'red rust']
CLASS_NAMES = labels  # alias kept for readability elsewhere in the notebook
2. Exploratory data analysis¶
2.1 Identifying the proportion of each class.¶
train_path = '/kaggle/input/cashew-image-dataset/Cashew/train_set'
test_path = '/kaggle/input/cashew-image-dataset/Cashew/test_set'

# Count the images per category, accumulating across both splits so each
# category's total covers train_set + test_set.
category_counts = {}
for split_path in (train_path, test_path):
    for category in os.listdir(split_path):
        category_folder = os.path.join(split_path, category)
        if not os.path.isdir(category_folder):
            continue
        n_images = len(os.listdir(category_folder))
        category_counts[category] = category_counts.get(category, 0) + n_images

labels = list(category_counts.keys())
counts = list(category_counts.values())

# Pie chart of the overall class distribution.
colors = sns.color_palette("pastel", len(labels))
plt.figure(figsize=(8, 8))
plt.pie(counts, labels=labels, autopct='%1.1f%%', startangle=90, colors=colors)
plt.title("Image Distribution Across Categories", fontsize=16)
plt.show()
Results: As we can see, the dataset is fairly well balanced, except for the gumosis class, which has noticeably fewer images.
2.2 Identifying the number of images in each class¶
# Horizontal bar chart: absolute number of images per category.
plt.figure(figsize=(12, 6))
sns.barplot(x=counts, y=labels, palette="viridis", orient="h")

# Annotate each bar with its exact count just past the bar's end.
for bar_idx, bar_count in enumerate(counts):
    plt.text(bar_count + 0.5, bar_idx, str(bar_count),
             va='center', fontsize=10, fontweight='bold', color='black')

# Titles, axis labels, and tick styling.
plt.title("Number of Images Per Category", fontsize=18, fontweight='bold')
plt.xlabel("Number of Images", fontsize=14)
plt.ylabel("Categories", fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.grid(axis='x', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()
2.3 Identifying the Image Dimension Distribution¶
# Collect (height, width) of every training image to inspect size variance.
# PIL's Image.open is lazy -- it reads only the file header to determine the
# size, so we avoid fully decoding thousands of images the way plt.imread did.
image_dimensions = []
for category in os.listdir(train_path):
    category_folder = os.path.join(train_path, category)
    if os.path.isdir(category_folder):
        for img_name in os.listdir(category_folder):
            img_path = os.path.join(category_folder, img_name)
            with Image.open(img_path) as im:
                width, height = im.size  # PIL reports (width, height)
            image_dimensions.append((height, width))  # keep (height, width) order

# pandas is already imported at the top of the notebook, so the redundant
# mid-script `import pandas as pd` was dropped.
df_dims = pd.DataFrame(image_dimensions, columns=["Height", "Width"])
sns.boxplot(data=df_dims)
plt.title("Image Dimension Distribution")
plt.show()
2.4 Visualizing samples from each class¶
# Load every image (train_set and test_set pooled together) into memory,
# labelled by its folder name; the pooled data is re-shuffled and re-split
# into train/test further below.
X_train = []
y_train = []
image_size = 224  # EfficientNetB0's expected input resolution

for split in ('train_set', 'test_set'):
    for label in labels:
        folderPath = os.path.join('/kaggle/input/cashew-image-dataset/Cashew', split, label)
        for fname in tqdm(os.listdir(folderPath)):
            img = cv2.imread(os.path.join(folderPath, fname))
            if img is None:
                # Unreadable / non-image file: skip instead of crashing in resize.
                continue
            # OpenCV decodes to BGR; convert to RGB so the channel order matches
            # the ImageNet-pretrained weights and matplotlib's imshow.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            X_train.append(cv2.resize(img, (image_size, image_size)))
            y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)
100%|██████████| 4751/4751 [00:12<00:00, 367.07it/s] 100%|██████████| 3102/3102 [00:08<00:00, 384.12it/s] 100%|██████████| 5877/5877 [00:15<00:00, 374.38it/s] 100%|██████████| 1714/1714 [00:04<00:00, 349.30it/s] 100%|██████████| 3466/3466 [00:09<00:00, 361.17it/s] 100%|██████████| 1815/1815 [00:06<00:00, 261.56it/s] 100%|██████████| 1838/1838 [00:06<00:00, 275.41it/s] 100%|██████████| 1336/1336 [00:04<00:00, 294.36it/s] 100%|██████████| 425/425 [00:01<00:00, 277.75it/s] 100%|██████████| 1487/1487 [00:05<00:00, 292.17it/s]
# Display the first occurrence of each class in the loaded data, one image
# per subplot.
fig, ax = plt.subplots(1, 5, figsize=(20, 20))
fig.text(s='Sample Image From Each Class', size=18, fontweight='bold',
         fontname='monospace', color=colors_dark[1], y=0.62, x=0.4, alpha=0.8)
for slot, class_name in enumerate(labels):
    # Index of the first sample carrying this label.
    first_idx = next(idx for idx in range(len(y_train)) if y_train[idx] == class_name)
    ax[slot].imshow(X_train[first_idx])
    ax[slot].set_title(y_train[first_idx])
    ax[slot].axis('off')
3. Getting the training and testing data ready¶
3.1 Shuffling the data to randomize the sample order¶
# Shuffle images and labels in unison (fixed seed for reproducibility).
X_train, y_train = shuffle(X_train,y_train, random_state=101)
X_train.shape # Shape of the X_train
# (Number of images , image size, image size, number of color channels)
(25811, 224, 224, 3)
# NOTE(review): the dataset's original train/test folders were pooled above,
# so this 80/20 split defines the actual train/test partition used from here on.
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train, test_size=0.2,random_state=101)
3.3 One Hot Encoding of the labels¶
# Map each string label to its integer index in `labels`, then one-hot encode
# both target vectors for the 5-way softmax output.
y_train = tf.keras.utils.to_categorical([labels.index(lbl) for lbl in y_train])
y_test = tf.keras.utils.to_categorical([labels.index(lbl) for lbl in y_test])
4. Setting up the neural net and training¶
4.1 Transfer Learning¶
Deep convolutional neural network models may take days or even weeks to train on very large datasets.
A way to short-cut this process is to re-use the model weights from pre-trained models that were developed for standard computer vision benchmark datasets, such as the ImageNet image recognition tasks. Top performing models can be downloaded and used directly, or integrated into a new model for your own computer vision problems.
In this Project, We'll be using the EfficientNetB0 model which will use the weights from the ImageNet dataset.
The include_top parameter is set to False so that the network doesn't include the top layer/output layer from the pre-built model which allows us to add our own output layer depending upon our use case!
effnet = EfficientNetB0(weights='imagenet',include_top=False,input_shape=(image_size,image_size,3))
4.2 Layers¶
GlobalAveragePooling2D -> This layer acts similarly to the Max Pooling layer in CNNs, the only difference being that it pools using the average value instead of the maximum. This greatly reduces the computational load on the machine during training.
Dropout -> This layer omits some of the neurons at each step, making the remaining neurons more independent of their neighbouring neurons, which helps avoid overfitting. The neurons to be omitted are selected at random. The rate parameter is the likelihood of a neuron's activation being set to 0, thus dropping out that neuron.
Dense -> This is the output layer which classifies the image into 1 of the 5 possible classes. It uses the softmax function which is a generalization of the sigmoid function.
# Classification head stacked on top of the EfficientNetB0 feature extractor.
x = tf.keras.layers.GlobalAveragePooling2D()(effnet.output)  # collapse spatial dims
x = tf.keras.layers.Dropout(rate=0.5)(x)                     # regularization
outputs = tf.keras.layers.Dense(5, activation='softmax')(x)  # 5 disease classes
model = tf.keras.models.Model(inputs=effnet.input, outputs=outputs)
4.3 Model summary¶
model.summary()
Model: "functional_3"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input_layer_1 │ (None, 224, 224, │ 0 │ - │ │ (InputLayer) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ rescaling_2 │ (None, 224, 224, │ 0 │ input_layer_1[0]… │ │ (Rescaling) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_1 │ (None, 224, 224, │ 7 │ rescaling_2[0][0] │ │ (Normalization) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ rescaling_3 │ (None, 224, 224, │ 0 │ normalization_1[… │ │ (Rescaling) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv_pad │ (None, 225, 225, │ 0 │ rescaling_3[0][0] │ │ (ZeroPadding2D) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv (Conv2D) │ (None, 112, 112, │ 864 │ stem_conv_pad[0]… │ │ │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_bn │ (None, 112, 112, │ 128 │ stem_conv[0][0] │ │ (BatchNormalizatio… │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_activation │ (None, 112, 112, │ 0 │ stem_bn[0][0] │ │ (Activation) │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_dwconv │ (None, 112, 112, │ 288 │ stem_activation[… │ │ (DepthwiseConv2D) │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_bn │ (None, 112, 112, │ 128 │ block1a_dwconv[0… │ │ (BatchNormalizatio… │ 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_activation │ (None, 112, 112, │ 0 │ block1a_bn[0][0] │ │ (Activation) │ 32) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_squeeze │ (None, 32) │ 0 │ block1a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_reshape │ (None, 1, 1, 32) │ 0 │ block1a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_reduce │ (None, 1, 1, 8) │ 264 │ block1a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_expand │ (None, 1, 1, 32) │ 288 │ block1a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_se_excite │ (None, 112, 112, │ 0 │ block1a_activati… │ │ (Multiply) │ 32) │ │ block1a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_project_co… │ (None, 112, 112, │ 512 │ block1a_se_excit… │ │ (Conv2D) │ 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_project_bn │ (None, 112, 112, │ 64 │ block1a_project_… │ │ (BatchNormalizatio… │ 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_conv │ (None, 112, 112, │ 1,536 │ block1a_project_… │ │ (Conv2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_bn │ (None, 112, 112, │ 384 │ block2a_expand_c… │ │ (BatchNormalizatio… │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_act… │ (None, 112, 112, │ 0 │ block2a_expand_b… │ │ (Activation) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_dwconv_pad │ (None, 113, 113, │ 0 │ block2a_expand_a… │ │ (ZeroPadding2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_dwconv │ (None, 56, 56, │ 864 │ 
block2a_dwconv_p… │ │ (DepthwiseConv2D) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_bn │ (None, 56, 56, │ 384 │ block2a_dwconv[0… │ │ (BatchNormalizatio… │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_activation │ (None, 56, 56, │ 0 │ block2a_bn[0][0] │ │ (Activation) │ 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_squeeze │ (None, 96) │ 0 │ block2a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_reshape │ (None, 1, 1, 96) │ 0 │ block2a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_reduce │ (None, 1, 1, 4) │ 388 │ block2a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_expand │ (None, 1, 1, 96) │ 480 │ block2a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_se_excite │ (None, 56, 56, │ 0 │ block2a_activati… │ │ (Multiply) │ 96) │ │ block2a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_project_co… │ (None, 56, 56, │ 2,304 │ block2a_se_excit… │ │ (Conv2D) │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_project_bn │ (None, 56, 56, │ 96 │ block2a_project_… │ │ (BatchNormalizatio… │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_conv │ (None, 56, 56, │ 3,456 │ block2a_project_… │ │ (Conv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_bn │ (None, 56, 56, │ 576 │ block2b_expand_c… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block2b_expand_act… │ (None, 56, 56, │ 0 │ block2b_expand_b… │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_dwconv │ (None, 56, 56, │ 1,296 │ block2b_expand_a… │ │ (DepthwiseConv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_bn │ (None, 56, 56, │ 576 │ block2b_dwconv[0… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_activation │ (None, 56, 56, │ 0 │ block2b_bn[0][0] │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_squeeze │ (None, 144) │ 0 │ block2b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_reshape │ (None, 1, 1, 144) │ 0 │ block2b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_reduce │ (None, 1, 1, 6) │ 870 │ block2b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_expand │ (None, 1, 1, 144) │ 1,008 │ block2b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_se_excite │ (None, 56, 56, │ 0 │ block2b_activati… │ │ (Multiply) │ 144) │ │ block2b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_project_co… │ (None, 56, 56, │ 3,456 │ block2b_se_excit… │ │ (Conv2D) │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_project_bn │ (None, 56, 56, │ 96 │ block2b_project_… │ │ (BatchNormalizatio… │ 24) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_drop │ (None, 56, 56, │ 0 │ block2b_project_… │ │ (Dropout) │ 24) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_add (Add) │ (None, 56, 56, │ 0 │ block2b_drop[0][… │ │ │ 24) │ │ block2a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_conv │ (None, 56, 56, │ 3,456 │ block2b_add[0][0] │ │ (Conv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_bn │ (None, 56, 56, │ 576 │ block3a_expand_c… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_act… │ (None, 56, 56, │ 0 │ block3a_expand_b… │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_dwconv_pad │ (None, 59, 59, │ 0 │ block3a_expand_a… │ │ (ZeroPadding2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_dwconv │ (None, 28, 28, │ 3,600 │ block3a_dwconv_p… │ │ (DepthwiseConv2D) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_bn │ (None, 28, 28, │ 576 │ block3a_dwconv[0… │ │ (BatchNormalizatio… │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_activation │ (None, 28, 28, │ 0 │ block3a_bn[0][0] │ │ (Activation) │ 144) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_squeeze │ (None, 144) │ 0 │ block3a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_reshape │ (None, 1, 1, 144) │ 0 │ block3a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_reduce │ (None, 1, 1, 6) │ 870 │ block3a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_expand │ (None, 1, 1, 144) │ 1,008 │ 
block3a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_se_excite │ (None, 28, 28, │ 0 │ block3a_activati… │ │ (Multiply) │ 144) │ │ block3a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_project_co… │ (None, 28, 28, │ 5,760 │ block3a_se_excit… │ │ (Conv2D) │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_project_bn │ (None, 28, 28, │ 160 │ block3a_project_… │ │ (BatchNormalizatio… │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_conv │ (None, 28, 28, │ 9,600 │ block3a_project_… │ │ (Conv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_bn │ (None, 28, 28, │ 960 │ block3b_expand_c… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_act… │ (None, 28, 28, │ 0 │ block3b_expand_b… │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_dwconv │ (None, 28, 28, │ 6,000 │ block3b_expand_a… │ │ (DepthwiseConv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_bn │ (None, 28, 28, │ 960 │ block3b_dwconv[0… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_activation │ (None, 28, 28, │ 0 │ block3b_bn[0][0] │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_squeeze │ (None, 240) │ 0 │ block3b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_reshape │ (None, 1, 1, 240) │ 0 │ block3b_se_squee… │ │ (Reshape) │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_reduce │ (None, 1, 1, 10) │ 2,410 │ block3b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_expand │ (None, 1, 1, 240) │ 2,640 │ block3b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_se_excite │ (None, 28, 28, │ 0 │ block3b_activati… │ │ (Multiply) │ 240) │ │ block3b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_project_co… │ (None, 28, 28, │ 9,600 │ block3b_se_excit… │ │ (Conv2D) │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_project_bn │ (None, 28, 28, │ 160 │ block3b_project_… │ │ (BatchNormalizatio… │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_drop │ (None, 28, 28, │ 0 │ block3b_project_… │ │ (Dropout) │ 40) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_add (Add) │ (None, 28, 28, │ 0 │ block3b_drop[0][… │ │ │ 40) │ │ block3a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_conv │ (None, 28, 28, │ 9,600 │ block3b_add[0][0] │ │ (Conv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_bn │ (None, 28, 28, │ 960 │ block4a_expand_c… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_act… │ (None, 28, 28, │ 0 │ block4a_expand_b… │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_dwconv_pad │ (None, 29, 29, │ 0 │ block4a_expand_a… │ │ (ZeroPadding2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_dwconv │ (None, 14, 14, │ 2,160 
│ block4a_dwconv_p… │ │ (DepthwiseConv2D) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_bn │ (None, 14, 14, │ 960 │ block4a_dwconv[0… │ │ (BatchNormalizatio… │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_activation │ (None, 14, 14, │ 0 │ block4a_bn[0][0] │ │ (Activation) │ 240) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_squeeze │ (None, 240) │ 0 │ block4a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_reshape │ (None, 1, 1, 240) │ 0 │ block4a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_reduce │ (None, 1, 1, 10) │ 2,410 │ block4a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_expand │ (None, 1, 1, 240) │ 2,640 │ block4a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_excite │ (None, 14, 14, │ 0 │ block4a_activati… │ │ (Multiply) │ 240) │ │ block4a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_project_co… │ (None, 14, 14, │ 19,200 │ block4a_se_excit… │ │ (Conv2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_project_bn │ (None, 14, 14, │ 320 │ block4a_project_… │ │ (BatchNormalizatio… │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_conv │ (None, 14, 14, │ 38,400 │ block4a_project_… │ │ (Conv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_bn │ (None, 14, 14, │ 1,920 │ block4b_expand_c… │ │ (BatchNormalizatio… │ 480) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_act… │ (None, 14, 14, │ 0 │ block4b_expand_b… │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_dwconv │ (None, 14, 14, │ 4,320 │ block4b_expand_a… │ │ (DepthwiseConv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_bn │ (None, 14, 14, │ 1,920 │ block4b_dwconv[0… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_activation │ (None, 14, 14, │ 0 │ block4b_bn[0][0] │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_squeeze │ (None, 480) │ 0 │ block4b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_reshape │ (None, 1, 1, 480) │ 0 │ block4b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block4b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_expand │ (None, 1, 1, 480) │ 10,080 │ block4b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_excite │ (None, 14, 14, │ 0 │ block4b_activati… │ │ (Multiply) │ 480) │ │ block4b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_project_co… │ (None, 14, 14, │ 38,400 │ block4b_se_excit… │ │ (Conv2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_project_bn │ (None, 14, 14, │ 320 │ block4b_project_… │ │ (BatchNormalizatio… │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_drop │ (None, 14, 14, │ 0 │ 
block4b_project_… │ │ (Dropout) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_add (Add) │ (None, 14, 14, │ 0 │ block4b_drop[0][… │ │ │ 80) │ │ block4a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_conv │ (None, 14, 14, │ 38,400 │ block4b_add[0][0] │ │ (Conv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_bn │ (None, 14, 14, │ 1,920 │ block4c_expand_c… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_act… │ (None, 14, 14, │ 0 │ block4c_expand_b… │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_dwconv │ (None, 14, 14, │ 4,320 │ block4c_expand_a… │ │ (DepthwiseConv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_bn │ (None, 14, 14, │ 1,920 │ block4c_dwconv[0… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_activation │ (None, 14, 14, │ 0 │ block4c_bn[0][0] │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_squeeze │ (None, 480) │ 0 │ block4c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_reshape │ (None, 1, 1, 480) │ 0 │ block4c_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block4c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_expand │ (None, 1, 1, 480) │ 10,080 │ block4c_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block4c_se_excite │ (None, 14, 14, │ 0 │ block4c_activati… │ │ (Multiply) │ 480) │ │ block4c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_project_co… │ (None, 14, 14, │ 38,400 │ block4c_se_excit… │ │ (Conv2D) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_project_bn │ (None, 14, 14, │ 320 │ block4c_project_… │ │ (BatchNormalizatio… │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_drop │ (None, 14, 14, │ 0 │ block4c_project_… │ │ (Dropout) │ 80) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_add (Add) │ (None, 14, 14, │ 0 │ block4c_drop[0][… │ │ │ 80) │ │ block4b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_conv │ (None, 14, 14, │ 38,400 │ block4c_add[0][0] │ │ (Conv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_bn │ (None, 14, 14, │ 1,920 │ block5a_expand_c… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_act… │ (None, 14, 14, │ 0 │ block5a_expand_b… │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_dwconv │ (None, 14, 14, │ 12,000 │ block5a_expand_a… │ │ (DepthwiseConv2D) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_bn │ (None, 14, 14, │ 1,920 │ block5a_dwconv[0… │ │ (BatchNormalizatio… │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_activation │ (None, 14, 14, │ 0 │ block5a_bn[0][0] │ │ (Activation) │ 480) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_squeeze │ (None, 480) │ 0 │ block5a_activati… │ │ (GlobalAveragePool… │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_reshape │ (None, 1, 1, 480) │ 0 │ block5a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_reduce │ (None, 1, 1, 20) │ 9,620 │ block5a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_expand │ (None, 1, 1, 480) │ 10,080 │ block5a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_excite │ (None, 14, 14, │ 0 │ block5a_activati… │ │ (Multiply) │ 480) │ │ block5a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_project_co… │ (None, 14, 14, │ 53,760 │ block5a_se_excit… │ │ (Conv2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_project_bn │ (None, 14, 14, │ 448 │ block5a_project_… │ │ (BatchNormalizatio… │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_conv │ (None, 14, 14, │ 75,264 │ block5a_project_… │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_bn │ (None, 14, 14, │ 2,688 │ block5b_expand_c… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_act… │ (None, 14, 14, │ 0 │ block5b_expand_b… │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_dwconv │ (None, 14, 14, │ 16,800 │ block5b_expand_a… │ │ (DepthwiseConv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_bn │ (None, 14, 14, │ 2,688 │ block5b_dwconv[0… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_activation │ 
(None, 14, 14, │ 0 │ block5b_bn[0][0] │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_squeeze │ (None, 672) │ 0 │ block5b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_reshape │ (None, 1, 1, 672) │ 0 │ block5b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_excite │ (None, 14, 14, │ 0 │ block5b_activati… │ │ (Multiply) │ 672) │ │ block5b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_project_co… │ (None, 14, 14, │ 75,264 │ block5b_se_excit… │ │ (Conv2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_project_bn │ (None, 14, 14, │ 448 │ block5b_project_… │ │ (BatchNormalizatio… │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_drop │ (None, 14, 14, │ 0 │ block5b_project_… │ │ (Dropout) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_add (Add) │ (None, 14, 14, │ 0 │ block5b_drop[0][… │ │ │ 112) │ │ block5a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_conv │ (None, 14, 14, │ 75,264 │ block5b_add[0][0] │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_bn │ (None, 14, 14, │ 2,688 │ block5c_expand_c… │ │ (BatchNormalizatio… │ 672) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_act… │ (None, 14, 14, │ 0 │ block5c_expand_b… │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_dwconv │ (None, 14, 14, │ 16,800 │ block5c_expand_a… │ │ (DepthwiseConv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_bn │ (None, 14, 14, │ 2,688 │ block5c_dwconv[0… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_activation │ (None, 14, 14, │ 0 │ block5c_bn[0][0] │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_squeeze │ (None, 672) │ 0 │ block5c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_reshape │ (None, 1, 1, 672) │ 0 │ block5c_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5c_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_excite │ (None, 14, 14, │ 0 │ block5c_activati… │ │ (Multiply) │ 672) │ │ block5c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_project_co… │ (None, 14, 14, │ 75,264 │ block5c_se_excit… │ │ (Conv2D) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_project_bn │ (None, 14, 14, │ 448 │ block5c_project_… │ │ (BatchNormalizatio… │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_drop │ (None, 14, 14, │ 0 │ 
block5c_project_… │ │ (Dropout) │ 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_add (Add) │ (None, 14, 14, │ 0 │ block5c_drop[0][… │ │ │ 112) │ │ block5b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_conv │ (None, 14, 14, │ 75,264 │ block5c_add[0][0] │ │ (Conv2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_bn │ (None, 14, 14, │ 2,688 │ block6a_expand_c… │ │ (BatchNormalizatio… │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_act… │ (None, 14, 14, │ 0 │ block6a_expand_b… │ │ (Activation) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_dwconv_pad │ (None, 17, 17, │ 0 │ block6a_expand_a… │ │ (ZeroPadding2D) │ 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_dwconv │ (None, 7, 7, 672) │ 16,800 │ block6a_dwconv_p… │ │ (DepthwiseConv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_bn │ (None, 7, 7, 672) │ 2,688 │ block6a_dwconv[0… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_activation │ (None, 7, 7, 672) │ 0 │ block6a_bn[0][0] │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_squeeze │ (None, 672) │ 0 │ block6a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_reshape │ (None, 1, 1, 672) │ 0 │ block6a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block6a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block6a_se_expand │ (None, 1, 1, 672) │ 19,488 │ block6a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_excite │ (None, 7, 7, 672) │ 0 │ block6a_activati… │ │ (Multiply) │ │ │ block6a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_project_co… │ (None, 7, 7, 192) │ 129,024 │ block6a_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_project_bn │ (None, 7, 7, 192) │ 768 │ block6a_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_conv │ (None, 7, 7, │ 221,184 │ block6a_project_… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_bn │ (None, 7, 7, │ 4,608 │ block6b_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_act… │ (None, 7, 7, │ 0 │ block6b_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_dwconv │ (None, 7, 7, │ 28,800 │ block6b_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_bn │ (None, 7, 7, │ 4,608 │ block6b_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_activation │ (None, 7, 7, │ 0 │ block6b_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_squeeze │ (None, 1152) │ 0 │ block6b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_reshape │ (None, 1, 1, │ 0 │ block6b_se_squee… │ │ (Reshape) │ 1152) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_expand │ (None, 1, 1, │ 56,448 │ block6b_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_excite │ (None, 7, 7, │ 0 │ block6b_activati… │ │ (Multiply) │ 1152) │ │ block6b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_project_co… │ (None, 7, 7, 192) │ 221,184 │ block6b_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_project_bn │ (None, 7, 7, 192) │ 768 │ block6b_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_drop │ (None, 7, 7, 192) │ 0 │ block6b_project_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_add (Add) │ (None, 7, 7, 192) │ 0 │ block6b_drop[0][… │ │ │ │ │ block6a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_conv │ (None, 7, 7, │ 221,184 │ block6b_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_bn │ (None, 7, 7, │ 4,608 │ block6c_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_act… │ (None, 7, 7, │ 0 │ block6c_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_dwconv │ (None, 7, 7, │ 28,800 │ block6c_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_bn │ (None, 7, 7, │ 4,608 │ 
block6c_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_activation │ (None, 7, 7, │ 0 │ block6c_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_squeeze │ (None, 1152) │ 0 │ block6c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_reshape │ (None, 1, 1, │ 0 │ block6c_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_expand │ (None, 1, 1, │ 56,448 │ block6c_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_excite │ (None, 7, 7, │ 0 │ block6c_activati… │ │ (Multiply) │ 1152) │ │ block6c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_project_co… │ (None, 7, 7, 192) │ 221,184 │ block6c_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_project_bn │ (None, 7, 7, 192) │ 768 │ block6c_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_drop │ (None, 7, 7, 192) │ 0 │ block6c_project_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_add (Add) │ (None, 7, 7, 192) │ 0 │ block6c_drop[0][… │ │ │ │ │ block6b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_expand_conv │ (None, 7, 7, │ 221,184 │ block6c_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block6d_expand_bn │ (None, 7, 7, │ 4,608 │ block6d_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_expand_act… │ (None, 7, 7, │ 0 │ block6d_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_dwconv │ (None, 7, 7, │ 28,800 │ block6d_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_bn │ (None, 7, 7, │ 4,608 │ block6d_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_activation │ (None, 7, 7, │ 0 │ block6d_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_squeeze │ (None, 1152) │ 0 │ block6d_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_reshape │ (None, 1, 1, │ 0 │ block6d_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6d_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_expand │ (None, 1, 1, │ 56,448 │ block6d_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_excite │ (None, 7, 7, │ 0 │ block6d_activati… │ │ (Multiply) │ 1152) │ │ block6d_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_project_co… │ (None, 7, 7, 192) │ 221,184 │ block6d_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_project_bn │ (None, 7, 7, 192) │ 768 │ block6d_project_… │ │ (BatchNormalizatio… │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_drop │ (None, 7, 7, 192) │ 0 │ block6d_project_… │ │ (Dropout) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_add (Add) │ (None, 7, 7, 192) │ 0 │ block6d_drop[0][… │ │ │ │ │ block6c_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_expand_conv │ (None, 7, 7, │ 221,184 │ block6d_add[0][0] │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_expand_bn │ (None, 7, 7, │ 4,608 │ block7a_expand_c… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_expand_act… │ (None, 7, 7, │ 0 │ block7a_expand_b… │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_dwconv │ (None, 7, 7, │ 10,368 │ block7a_expand_a… │ │ (DepthwiseConv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_bn │ (None, 7, 7, │ 4,608 │ block7a_dwconv[0… │ │ (BatchNormalizatio… │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_activation │ (None, 7, 7, │ 0 │ block7a_bn[0][0] │ │ (Activation) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_squeeze │ (None, 1152) │ 0 │ block7a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_reshape │ (None, 1, 1, │ 0 │ block7a_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block7a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_expand │ (None, 1, 1, │ 56,448 │ block7a_se_reduc… 
│ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_se_excite │ (None, 7, 7, │ 0 │ block7a_activati… │ │ (Multiply) │ 1152) │ │ block7a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_project_co… │ (None, 7, 7, 320) │ 368,640 │ block7a_se_excit… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block7a_project_bn │ (None, 7, 7, 320) │ 1,280 │ block7a_project_… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_conv (Conv2D) │ (None, 7, 7, │ 409,600 │ block7a_project_… │ │ │ 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_bn │ (None, 7, 7, │ 5,120 │ top_conv[0][0] │ │ (BatchNormalizatio… │ 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_activation │ (None, 7, 7, │ 0 │ top_bn[0][0] │ │ (Activation) │ 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ global_average_poo… │ (None, 1280) │ 0 │ top_activation[0… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dropout_1 (Dropout) │ (None, 1280) │ 0 │ global_average_p… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_1 (Dense) │ (None, 5) │ 6,405 │ dropout_1[0][0] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 4,055,976 (15.47 MB)
Trainable params: 4,013,953 (15.31 MB)
Non-trainable params: 42,023 (164.16 KB)
4.4 Model compilation¶
Loss function¶
CategoricalCrossentropy() : Categorical Cross Entropy is also known as Softmax Loss. It's a softmax activation plus a Cross-Entropy loss used for multiclass classification. Using this loss, we can train a Convolutional Neural Network to output a probability over the N classes for each image.
Optimizer Function¶
Adam() : The Adam optimizer (short for Adaptive Moment Estimation) is a widely used optimization algorithm in neural networks that combines the advantages of two popular optimization methods: Momentum and RMSProp. It is particularly effective for handling sparse gradients, noisy objectives, and non-stationary problems.
Metrics¶
Accuracy() : It is the ratio of number of correct predictions to the total number of input samples.
# Compile the network for multi-class classification: softmax cross-entropy
# against the one-hot labels, Adam optimizer (Keras default lr=1e-3), and
# per-epoch accuracy tracking.
model.compile(
    loss='categorical_crossentropy',  # expects one-hot encoded targets
    optimizer='Adam',                 # Keras resolves optimizer strings case-insensitively
    metrics=['accuracy'],
)
4.5 Getting the callback functions ready¶
Callback Functions¶
TensorBoard()
TensorBoard is a visualization tool provided by TensorFlow. It allows us to monitor various metrics like loss, accuracy, learning rate, and more during the training process. In this setup:
- The logs are saved in a directory named
logs, which can be used to visualize training progress in TensorBoard.
- The logs are saved in a directory named
ModelCheckpoint()
- This callback is used to save the model during training. It ensures the best version of the model (based on a monitored metric) is saved automatically.
- In this setup:
- The model is saved with the filename
cashew-effnet.keras. - The metric being monitored is
val_accuracy(validation accuracy). - The model is saved only when it achieves the highest validation accuracy so far.
save_best_only=True ensures only the best model is saved. verbose=1 provides updates during the saving process.
- The model is saved with the filename
ReduceLROnPlateau()
- This callback reduces the learning rate when the monitored metric stops improving, helping the model converge better during later training stages.
- In this setup:
- The metric being monitored is
val_accuracy. - If the validation accuracy does not improve for 2 epochs, the learning rate is reduced by a factor of 0.3.
min_delta=0.001 sets the minimum change in validation accuracy required to qualify as an improvement. verbose=1 provides updates when the learning rate is reduced.
- The metric being monitored is
# --- Training callbacks -------------------------------------------------
# TensorBoard: write scalar logs under ./logs for later visualization.
tensorboard = TensorBoard(log_dir='logs')

# ModelCheckpoint: persist only the weights with the best validation
# accuracy seen so far, in the native Keras v3 format.
checkpoint = ModelCheckpoint(
    "cashew-effnet.keras",
    monitor="val_accuracy",
    save_best_only=True,
    mode="auto",
    verbose=1,
)

# ReduceLROnPlateau: shrink the learning rate by a factor of 0.3 whenever
# val_accuracy fails to improve by at least 0.001 for 2 consecutive epochs.
reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.3,
    patience=2,
    min_delta=0.001,
    mode='auto',
    verbose=1,
)
4.6 Model fitting¶
Model.fit()¶
The fit() method is used to train the model on the training data (X_train, y_train) for a specified number of epochs and batch size. It also incorporates the callbacks defined earlier to enhance and monitor the training process.
Key Parameters¶
X_train, y_train — X_train: the input training data (features); y_train: the corresponding labels for the training data.
validation_split=0.2- This reserves 20% of the training data for validation to evaluate the model’s performance during training.
epochs=12- Specifies that the model will train for 12 iterations over the entire training dataset.
verbose=1- Displays detailed training progress, including loss, accuracy, and other metrics per epoch.
batch_size=32- The model processes 32 samples at a time before updating the weights, which helps balance memory usage and speed.
callbacks=[tensorboard, checkpoint, reduce_lr]- Integrates the previously defined callbacks to monitor and optimize the training process:
TensorBoard: Logs metrics for visualization in TensorBoard.ModelCheckpoint: Saves the best model based on validation accuracy.ReduceLROnPlateau: Dynamically adjusts the learning rate to improve convergence.
- Integrates the previously defined callbacks to monitor and optimize the training process:
By combining these settings, the training process is efficient, monitored, and adaptable, ensuring better performance.
# Train for 12 epochs, holding out 20% of the training arrays as a
# validation split; the callbacks log metrics, checkpoint the best model,
# and adapt the learning rate on plateaus.
history = model.fit(
    X_train,
    y_train,
    validation_split=0.2,
    epochs=12,
    batch_size=32,
    verbose=1,
    callbacks=[tensorboard, checkpoint, reduce_lr],
)
Epoch 1/12
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1732167081.443094 778 service.cc:145] XLA service 0x7e07800a1590 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: I0000 00:00:1732167081.443153 778 service.cc:153] StreamExecutor device (0): Tesla T4, Compute Capability 7.5 I0000 00:00:1732167081.443157 778 service.cc:153] StreamExecutor device (1): Tesla T4, Compute Capability 7.5 I0000 00:00:1732167138.351220 778 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 126ms/step - accuracy: 0.8518 - loss: 0.4263
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1732167254.390334 778 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'input_multiply_reduce_fusion', 160 bytes spill stores, 160 bytes spill loads ptxas warning : Registers are spilled to local memory in function 'input_reduce_fusion', 24 bytes spill stores, 24 bytes spill loads ptxas warning : Registers are spilled to local memory in function 'input_reduce_fusion_2', 24 bytes spill stores, 24 bytes spill loads ptxas warning : Registers are spilled to local memory in function 'input_reduce_fusion_3', 16 bytes spill stores, 16 bytes spill loads
517/517 ━━━━━━━━━━━━━━━━━━━━ 0s 223ms/step - accuracy: 0.8519 - loss: 0.4260 Epoch 1: val_accuracy improved from -inf to 0.93923, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 221s 252ms/step - accuracy: 0.8520 - loss: 0.4258 - val_accuracy: 0.9392 - val_loss: 0.2125 - learning_rate: 0.0010 Epoch 2/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 128ms/step - accuracy: 0.9540 - loss: 0.1572 Epoch 2: val_accuracy improved from 0.93923 to 0.95182, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 71s 137ms/step - accuracy: 0.9540 - loss: 0.1572 - val_accuracy: 0.9518 - val_loss: 0.1774 - learning_rate: 0.0010 Epoch 3/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 132ms/step - accuracy: 0.9620 - loss: 0.1243 Epoch 3: val_accuracy improved from 0.95182 to 0.96295, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 73s 142ms/step - accuracy: 0.9619 - loss: 0.1243 - val_accuracy: 0.9630 - val_loss: 0.1312 - learning_rate: 0.0010 Epoch 4/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 134ms/step - accuracy: 0.9693 - loss: 0.0925 Epoch 4: val_accuracy did not improve from 0.96295 517/517 ━━━━━━━━━━━━━━━━━━━━ 74s 143ms/step - accuracy: 0.9693 - loss: 0.0925 - val_accuracy: 0.9465 - val_loss: 0.1622 - learning_rate: 0.0010 Epoch 5/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 136ms/step - accuracy: 0.9676 - loss: 0.0970 Epoch 5: val_accuracy improved from 0.96295 to 0.97361, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 146ms/step - accuracy: 0.9676 - loss: 0.0970 - val_accuracy: 0.9736 - val_loss: 0.0861 - learning_rate: 0.0010 Epoch 6/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9775 - loss: 0.0670 Epoch 6: val_accuracy did not improve from 0.97361 517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 145ms/step - accuracy: 0.9775 - loss: 0.0671 - val_accuracy: 0.9487 - val_loss: 0.1748 - learning_rate: 0.0010 Epoch 7/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9720 - loss: 0.0900 Epoch 7: val_accuracy did not improve from 0.97361 Epoch 7: 
ReduceLROnPlateau reducing learning rate to 0.0003000000142492354. 517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 145ms/step - accuracy: 0.9720 - loss: 0.0900 - val_accuracy: 0.9586 - val_loss: 0.1304 - learning_rate: 0.0010 Epoch 8/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9886 - loss: 0.0352 Epoch 8: val_accuracy improved from 0.97361 to 0.97893, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9887 - loss: 0.0352 - val_accuracy: 0.9789 - val_loss: 0.0635 - learning_rate: 3.0000e-04 Epoch 9/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9965 - loss: 0.0137 Epoch 9: val_accuracy improved from 0.97893 to 0.98160, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9965 - loss: 0.0137 - val_accuracy: 0.9816 - val_loss: 0.0708 - learning_rate: 3.0000e-04 Epoch 10/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9976 - loss: 0.0069 Epoch 10: val_accuracy improved from 0.98160 to 0.98184, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9976 - loss: 0.0069 - val_accuracy: 0.9818 - val_loss: 0.0775 - learning_rate: 3.0000e-04 Epoch 11/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9985 - loss: 0.0047 Epoch 11: val_accuracy did not improve from 0.98184 Epoch 11: ReduceLROnPlateau reducing learning rate to 9.000000427477062e-05. 517/517 ━━━━━━━━━━━━━━━━━━━━ 75s 145ms/step - accuracy: 0.9985 - loss: 0.0047 - val_accuracy: 0.9809 - val_loss: 0.0926 - learning_rate: 3.0000e-04 Epoch 12/12 516/517 ━━━━━━━━━━━━━━━━━━━━ 0s 137ms/step - accuracy: 0.9988 - loss: 0.0033 Epoch 12: val_accuracy improved from 0.98184 to 0.98329, saving model to cashew-effnet.keras 517/517 ━━━━━━━━━━━━━━━━━━━━ 76s 147ms/step - accuracy: 0.9988 - loss: 0.0033 - val_accuracy: 0.9833 - val_loss: 0.0817 - learning_rate: 9.0000e-05
4.7 Performance Analysis¶
# Plot training vs. validation accuracy (left) and loss (right) per epoch.
train_acc = history.history['accuracy']
train_loss = history.history['loss']
val_acc = history.history['val_accuracy']
val_loss = history.history['val_loss']

# Derive the x-axis from the recorded history rather than hard-coding 12,
# so the plot stays correct if the epoch count changes.
epochs = list(range(len(train_acc)))

fig, ax = plt.subplots(1, 2, figsize=(14, 7))
fig.text(s='Epochs vs. Training and Validation Accuracy/Loss', size=18, fontweight='bold',
         fontname='monospace', color=colors_dark[1], y=1, x=0.28, alpha=0.8)
sns.despine()

# Left panel: accuracy curves.
ax[0].plot(epochs, train_acc, marker='o', markerfacecolor=colors_green[2], color=colors_green[3],
           label='Training Accuracy')
ax[0].plot(epochs, val_acc, marker='o', markerfacecolor=colors_red[2], color=colors_red[3],
           label='Validation Accuracy')
ax[0].legend(frameon=False)
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Accuracy')
sns.despine()

# Right panel: loss curves.
ax[1].plot(epochs, train_loss, marker='o', markerfacecolor=colors_green[2], color=colors_green[3],
           label='Training Loss')
ax[1].plot(epochs, val_loss, marker='o', markerfacecolor=colors_red[2], color=colors_red[3],
           label='Validation Loss')
ax[1].legend(frameon=False)
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Training & Validation Loss')

# plt.show() works in scripts and notebooks alike; fig.show() can be a
# silent no-op (with a warning) outside an interactive GUI backend.
plt.show()
5. Performance analysis on test dataset¶
# Run inference on the held-out test set and reduce both the predicted
# probability vectors and the one-hot ground truth to integer class indices.
pred = model.predict(X_test).argmax(axis=1)   # predicted class per sample
y_test_new = np.argmax(y_test, axis=1)        # true class per sample
162/162 ━━━━━━━━━━━━━━━━━━━━ 17s 75ms/step
5.1 Classification Report and Confusion Matrix¶
Classification Report¶
classification_report(y_test_new, pred) — The classification report provides a detailed summary of the model's performance on the test data.
- Metrics included:
- Precision: The proportion of true positive predictions out of all positive predictions.
- Recall (Sensitivity): The proportion of true positive predictions out of actual positives.
- F1-Score: The harmonic mean of precision and recall, balancing both metrics.
- Support: The number of true instances for each class.
Confusion Matrix¶
confusion_matrix(y_test_new, pred) — The confusion matrix gives a tabular representation of the model's predictions versus the actual labels.
- Rows represent the actual classes, and columns represent the predicted classes.
- Components include:
- True Positives (TP): Correctly predicted positive instances.
- True Negatives (TN): Correctly predicted negative instances.
- False Positives (FP): Incorrectly predicted positive instances.
- False Negatives (FN): Incorrectly predicted negative instances.
These metrics allow us to evaluate how well the model classifies the test data, providing insights into strengths and areas for improvement.
# Per-class precision/recall/F1 summary (as a formatted string) and the
# raw confusion matrix (rows = true class, columns = predicted class).
cr = classification_report(y_test_new,pred)
cmetric = confusion_matrix(y_test_new,pred)
print(cr)
precision recall f1-score support
0 0.99 0.99 0.99 1296
1 0.96 0.95 0.96 995
2 0.98 0.99 0.99 1453
3 0.99 1.00 1.00 430
4 0.97 0.96 0.96 989
accuracy 0.98 5163
macro avg 0.98 0.98 0.98 5163
weighted avg 0.98 0.98 0.98 5163
5.2 Visualizing the classification report¶
# Parse the per-class precision/recall/F1 values back out of the text report.
# The original relied on fixed line offsets (raw_data[2:-5]), which silently
# breaks if sklearn changes its report layout or the class count changes.
# Instead, accept exactly the rows that look like per-class entries: five
# whitespace-separated fields whose first token is a numeric class id
# (this excludes the header, blank lines, and accuracy/macro/weighted rows).
raw_data = cr.split("\n")
precision = []
recall = []
f1_score = []
for line in raw_data:
    parts = line.split()
    if len(parts) == 5 and parts[0].isdigit():
        precision.append(float(parts[1]))
        recall.append(float(parts[2]))
        f1_score.append(float(parts[3]))
# Grouped bar chart: one cluster of (precision, recall, F1) bars per class.
x = np.arange(len(labels))  # one bar cluster per class label
width = 0.2                 # width of each bar within a cluster
colors = sns.color_palette("Set2", 3)
plt.figure(figsize=(12, 6))
plt.bar(x, precision, width, label='Precision', color=colors[0])
plt.bar(x + width, recall, width, label='Recall', color=colors[1])
plt.bar(x + 2 * width, f1_score, width, label='F1-score', color=colors[2])
# Annotate each bar with its value, placed just above the bar top.
for i, v in enumerate(precision):
    plt.text(i , v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
for i, v in enumerate(recall):
    plt.text(i + width , v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
for i, v in enumerate(f1_score):
    plt.text(i + 2 * width , v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xlabel('Class', fontsize=12)
plt.ylabel('Score', fontsize=12)
plt.title('Precision, Recall, and F1-score by Class', fontsize=14, fontweight='bold')
# Ticks are shifted by one bar width so they sit under the middle bar of
# each three-bar cluster.
plt.xticks(x + width, labels, rotation=15, fontsize=10)
# Reference line at the maximum possible score.
plt.axhline(y=1, color='red', linestyle='--', linewidth=1, label='Highest')
plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1), fontsize=10)
plt.tight_layout()
plt.show()
5.3 Visualizing the Confusion Matrix¶
# Render the confusion matrix as a heatmap with human-readable class names.
fig, ax = plt.subplots(figsize=(10, 8))
CLASS_NAMES = labels
disp = ConfusionMatrixDisplay(confusion_matrix=cmetric, display_labels=CLASS_NAMES)
disp.plot(ax=ax, cmap='binary')
# Re-apply the tick labels so the x-axis names are rotated and legible.
ax.set_xticks(range(len(CLASS_NAMES)))
ax.set_xticklabels(CLASS_NAMES, rotation=90)
ax.set_yticks(range(len(CLASS_NAMES)))
ax.set_yticklabels(CLASS_NAMES)
# Enlarge the per-cell count annotations drawn by disp.plot().
for text in ax.texts:
    text.set_fontsize(12)
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.grid(False)
plt.show()
5.4 Advanced visualizations¶
test_predictions = model.predict(X_test)
162/162 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step
# Collapse the one-hot ground-truth rows to integer class indices.
# A single vectorized np.argmax replaces the original per-row Python loop
# (which needed a tqdm progress bar); the resulting array is identical.
y_val_indexes = np.argmax(y_test, axis=1)
y_val_indexes[:5]
100%|██████████| 5163/5163 [00:00<00:00, 990676.22it/s]
array([0, 0, 4, 4, 2])
# Collapse the predicted probability vectors to their argmax class index.
y_test_classes = np.argmax(test_predictions, axis=1)
y_test_classes[0:5]
array([0, 0, 4, 4, 2])
y_test_classes.shape, y_val_indexes.shape
((5163,), (5163,))
# Overall test accuracy computed from integer labels vs. predicted classes.
y_true = y_val_indexes
accuracy = accuracy_score(y_true, y_test_classes)
accuracy
0.9783071857447221
# Inspect the raw prediction vector for a single test sample: its maximum
# score, that the scores sum to 1, and the winning class index/label.
index = 40
probs = test_predictions[index]
print(probs)
print(f"Max value (probability of prediction): {np.max(probs)}")
print(f"Sum: {np.sum(probs)}")
print(f"Max index: {np.argmax(probs)}")
print(f"Predicted label: {CLASS_NAMES[np.argmax(probs)]}")
[3.4615238e-10 8.7412627e-10 7.2288189e-09 4.4301933e-07 9.9999952e-01] Max value (probability of prediction): 0.9999995231628418 Sum: 1.0 Max index: 4 Predicted label: leaf miner
CLASS_NAMES[4]
'leaf miner'
def get_pred_label(prediction_probabilities):
    """Map a vector of per-class prediction scores to its class name.

    Picks the index with the highest score and looks it up in the
    module-level CLASS_NAMES list.
    """
    winning_index = np.argmax(prediction_probabilities)
    return CLASS_NAMES[winning_index]
# Demo: map the prediction vector of test sample 81 to its class name.
pred_label = get_pred_label(test_predictions[81])
pred_label
'red rust'
def plot_pred(prediction_probabilities, labels, images, n=1):
    """Show image n together with its predicted and ground-truth labels.

    The title reads "<predicted> <confidence%> <truth>" and is colored
    green when the prediction matches the truth label, red otherwise.
    """
    pred_prob = prediction_probabilities[n]
    true_label = CLASS_NAMES[labels[n]]
    image = images[n]
    pred_label = get_pred_label(pred_prob)
    # Draw the image without axis ticks.
    plt.imshow(image.astype("uint8"))
    plt.xticks([])
    plt.yticks([])
    # Green title for a correct prediction, red for a wrong one.
    color = "green" if pred_label == true_label else "red"
    plt.title("{} {:2.0f}% {}".format(pred_label,
                                      np.max(pred_prob)*100,
                                      true_label),
              color=color)
plot_pred(): visualizes the model's prediction for one sample, showing the predicted label, its confidence percentage, and the actual class label.
# Demo: show test sample 2400 with its prediction, confidence and truth label.
plot_pred(prediction_probabilities=test_predictions,
          labels=y_val_indexes,
          images=X_test,
          n=2400)
plt.show()
def plot_pred_conf(prediction_probabilities, labels, n=1):
    """Plot the top 10 highest prediction confidences for sample n.

    Draws one grey bar per class (labelled by class index, highest
    confidence first) and recolors the ground-truth class's bar green
    when it appears among the plotted classes.
    """
    pred_prob, true_label = prediction_probabilities[n], labels[n]
    # Class indices sorted by descending confidence, capped at 10.
    # The original took every class; the [:10] slice matches the documented
    # "top 10" intent and is identical whenever there are <= 10 classes.
    top_10_pred_indexes = pred_prob.argsort()[::-1][:10]
    # Confidence values in the same (descending) order.
    top_10_pred_values = pred_prob[top_10_pred_indexes]
    # Bars are labelled by class index.
    top_10_pred_labels = top_10_pred_indexes
    # Setup plot
    top_plot = plt.bar(np.arange(len(top_10_pred_labels)),
                       top_10_pred_values,
                       color="grey")
    plt.xticks(np.arange(len(top_10_pred_labels)),
               labels=top_10_pred_labels,
               rotation="vertical")
    # Highlight the ground-truth class bar, when it made the cut.
    if np.isin(true_label, top_10_pred_labels):
        top_plot[np.argmax(top_10_pred_labels == true_label)].set_color("green")
# Demo: confidence bars for test sample 8.
plot_pred_conf(prediction_probabilities = test_predictions,
               labels = y_val_indexes,
               n = 8)
# Grid of paired panels starting at test sample i_multiplier:
# left panel = image with prediction/truth title, right panel = the
# matching confidence bar chart for the same sample.
i_multiplier = 20  # index of the first sample shown
num_rows = 10
num_cols = 2
num_images = num_rows * num_cols
plt.figure(figsize=(10 * num_cols, 5 * num_rows))
for i in range(num_images):
    # Odd subplot slots: the image with its predicted/true labels.
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_pred(prediction_probabilities=test_predictions,
              labels=y_val_indexes,
              images=X_test,
              n=i+i_multiplier)
    # Even subplot slots: the confidence bars for the same sample.
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_pred_conf(prediction_probabilities=test_predictions,
                   labels=y_val_indexes,
                   n=i+i_multiplier)
plt.tight_layout(h_pad=1.0)
plt.show()